from hcquant.calendar import adjust_to_trade_date
from dateutil.parser import parse
import datetime as dts
import pandas as pd
from numpy import nan
# Holding returns per sid, compared against a benchmark

class Stats:
    """Return statistics for securities, read from the ``nWind`` database.

    ``engine`` is passed unchanged to :func:`pandas.read_sql`, so it must be a
    SQLAlchemy connectable or DBAPI connection that can query ``nWind.dbo``.
    """

    # Short benchmark aliases accepted by ``ret_compare_benchmark``; any
    # other value is used verbatim as an index code.
    _BENCHMARK_ALIASES = {'hs300': '000300.SH', 'zz500': '000905.SH'}

    def __init__(self, engine):
        self.engine = engine

    @staticmethod
    def _to_wind_code(code):
        """Convert an 11-character code (e.g. ``600000.XSHG``) to Wind format.

        The first 7 characters (six digits plus separator) are kept and the
        suffix becomes ``SH`` when the digits are >= ``600000``, else ``SZ``.
        Codes of any other length are returned unchanged.
        """
        if len(code) == 11:
            suffix = 'SH' if code[:6] >= '600000' else 'SZ'
            return code[:7] + suffix
        return code

    @staticmethod
    def _sql_literal(value):
        """Render ``value`` as a single-quoted SQL string literal."""
        # Doubling embedded quotes keeps a stray ' from breaking the query.
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _period_return(closes):
        """Return ``closes[1] / closes[0] - 1``, or NaN if fewer than 2 prices."""
        if len(closes) < 2:
            return nan
        return closes.iloc[1] / closes.iloc[0] - 1

    def ret_compare_benchmark(self, sid, begin_date, benchmark='hs300', period=30):
        """Compute the holding-period return of ``sid`` and its excess over a benchmark.

        Parameters
        ----------
        sid : str or iterable of str
            One security code or several. 11-character codes (e.g.
            ``600000.XSHG``) are converted to Wind codes (``600000.SH``);
            other codes are used as given.
        begin_date
            Start date, moved with ``adjust_to_trade_date(..., adjust='last')``.
            The adjusted value is fed to ``dateutil.parser.parse`` and written
            into the SQL text, so it is presumably a ``YYYYMMDD`` string —
            TODO confirm against ``adjust_to_trade_date``.
        benchmark : str
            ``'hs300'``, ``'zz500'`` or an index code such as ``'000300.SH'``.
        period : int
            Holding period in calendar days. The end date is
            ``begin_date + period`` days, adjusted with ``adjust='last'``.

        Returns
        -------
        pandas.DataFrame
            One row per requested sid, in request order, with columns ``sid``,
            ``ret`` (adjusted-close return from begin to end date) and
            ``relative_ret`` (``ret`` minus the benchmark's close-to-close
            return). ``ret`` is NaN when the sid lacks a price on either date;
            ``relative_ret`` is NaN when either return is NaN.

        Raises
        ------
        ValueError
            If ``sid`` contains no codes.
        """
        import logging

        if isinstance(sid, str):
            sid = [sid]
        # Normalise each code individually so mixed inputs are handled too.
        buylist = tuple(self._to_wind_code(code) for code in sid)
        if not buylist:
            raise ValueError('sid must not be empty')

        begin_date = adjust_to_trade_date(begin_date, adjust='last')
        end_date = adjust_to_trade_date(
            parse(begin_date) + dts.timedelta(days=period), adjust='last')
        date_filter = (f"(TRADE_DT = {self._sql_literal(begin_date)} "
                       f"or TRADE_DT = {self._sql_literal(end_date)})")

        # Build the IN list explicitly: formatting a tuple directly renders a
        # single sid as ('X',), which is a SQL syntax error.
        sid_list = ', '.join(self._sql_literal(code) for code in buylist)
        sql1 = f"""
        select TRADE_DT trade_dt, S_INFO_WINDCODE sid, S_DQ_ADJCLOSE adjclose from nWind.dbo.AShareEODPrices
        where {date_filter}
        and S_INFO_WINDCODE in ({sid_list})
        ORDER by S_INFO_WINDCODE, TRADE_DT
        """
        prices = pd.read_sql(sql1, self.engine)
        # Rows are ordered by TRADE_DT within each sid, so position 0 is the
        # begin date and position 1 the end date.
        ret = prices.groupby('sid')['adjclose'].apply(self._period_return)

        # Surface returns above 50% so they can be checked by hand.
        log = logging.getLogger(__name__)
        for code, value in ret[ret > 0.5].items():
            log.warning('ret_compare_benchmark: %s returned %.2f%% from %s to %s',
                        code, value * 100, begin_date, end_date)

        # reindex (unlike .loc) yields NaN for sids that had no price rows,
        # and keeps the caller's order and any duplicates.
        result = (ret.reindex(pd.Index(buylist, name='sid'))
                     .rename('ret')
                     .reset_index())

        benchmark = self._BENCHMARK_ALIASES.get(benchmark, benchmark)
        sql2 = f"""
        select S_DQ_CLOSE adjclose from nWind.dbo.AINDEXEODPRICES
        where S_INFO_WINDCODE = {self._sql_literal(benchmark)}
        and {date_filter}
        ORDER BY TRADE_DT
        """
        index_prices = pd.read_sql(sql2, self.engine)
        benchmark_ret = self._period_return(index_prices['adjclose'])
        result['relative_ret'] = result['ret'] - benchmark_ret
        return result